package org.zjvis.datascience.common.etl;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Joiner;
import com.google.common.collect.Lists;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.parser.SimpleNode;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.zjvis.datascience.common.constant.SqlTemplate;
import org.zjvis.datascience.common.enums.ETLEnum;
import org.zjvis.datascience.common.enums.SubTypeEnum;
import org.zjvis.datascience.common.exception.SqlParserException;
import org.zjvis.datascience.common.model.ApiResultCode;
import org.zjvis.datascience.common.sql.SqlHelper;
import org.zjvis.datascience.common.util.StringUtil;
import org.zjvis.datascience.common.util.ToolUtil;
import org.zjvis.datascience.common.vo.TaskVO;

import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * @description ETL-SQL 自定义SQL类
 * @date 2021-12-27
 */
public class Sql extends BaseETL {

    private final static Logger logger = LoggerFactory.getLogger("Sql");

    private String text;

    private static Set<String> boundSets = new HashSet<>();

    static {
        boundSets.add(" ");
        boundSets.add(",");
        boundSets.add("=");
        boundSets.add("!");
        boundSets.add("<");
        boundSets.add(">");
        boundSets.add("\n");
        boundSets.add("\t");
        boundSets.add("\r");
    }

    public Sql() {
        super(ETLEnum.SQL.name(), SubTypeEnum.ETL_OPERATE.getVal(),
                SubTypeEnum.ETL_OPERATE.getDesc());
    }

    public void parserConf(JSONObject conf) {
        text = conf.getString("sql");
    }

    private int getRightBoundIndex(String text, int fromIndex) {
        int length = text.length();
        int i = fromIndex;
        for (; i < length; ++i) {
            String t = text.substring(i, i + 1);
            if (boundSets.contains(t)) {
                return i;
            }
        }
        return i;
    }

    private String replaceTableName(long timeStamp, String modifySql) {
        if (StringUtils.isNotEmpty(modifySql)) {
            text = modifySql;
        }
        if (!text.contains("$")) {
            return text;
        }
        StringBuffer sb = new StringBuffer();
        int fromIndex = 0;
        do {
            int index = text.indexOf("$", fromIndex);
            if (index == -1) {
                sb.append(text.substring(fromIndex));
                break;
            } else {
                sb.append(text, fromIndex, index);
                int blankIndex = getRightBoundIndex(text, index);
                String tableName = text.substring(index + 1, blankIndex);
                if (tableName.startsWith(SqlTemplate.SCHEMA + ".") || tableName
                        .startsWith(SqlTemplate.SOURCE_SCHEMA + ".")) {
                    String[] parts = tableName.split("\\.");
                    if (parts.length == 2) {
                        tableName = ToolUtil.alignTableName(tableName, timeStamp);
                    } else {
                        int lastIndex = tableName.lastIndexOf(".");
                        tableName = ToolUtil
                                .alignTableName(tableName.substring(0, lastIndex), timeStamp);
                        tableName += "." + parts[parts.length - 1];
                    }
                } else {
                    int lastIndex = tableName.lastIndexOf(".");
                    if (lastIndex == -1) {
                        tableName = ToolUtil.alignTableName(tableName, timeStamp);
                    } else {
                        String[] parts = tableName.split("\\.");
                        tableName = ToolUtil
                                .alignTableName(tableName.substring(0, lastIndex), timeStamp);
                        tableName += "." + parts[parts.length - 1];
                    }
                }
                sb.append(tableName);
                fromIndex = blankIndex;
            }
        } while (fromIndex < text.length());
        return sb.toString();
    }

    private boolean compareTableName(Table target, String candidate) {
        String schemaName = target.getSchemaName();
        String name = target.getName();
        String[] parts = candidate.split("\\.");
        if (parts.length > 1) {
            return schemaName.equals(parts[0]) && name.equals(parts[1]);
        } else {
            return name.equals(parts[0]);
        }
    }

    private String getTempTableName(JSONObject inputItemObj, Long timeStamp) {
        if (inputItemObj.getString("tableName").endsWith("_")) {
            //如果是清洗节点 需要对 input 对象中拿到的tableName做修改
            return inputItemObj.getString("tableName") + "" + timeStamp;
        } else {
            return inputItemObj.getString("tableName");
        }
    }

    private static String extractColumnName(SelectExpressionItem item) {
        String columnName = "";
        SimpleNode node = item.getASTNode();
        Object value = node.jjtGetValue();
        if (value instanceof Column) {
            columnName = ((Column) value).getColumnName();
        } else if (value instanceof Function) {
            columnName = value.toString();
        } else {
            // 增加对select 'aaa' from table; 的支持
            columnName = String.valueOf(value);
            columnName = columnName.replace("'", "");
            columnName = columnName.replace("\"", "");
            columnName = columnName.replace("`", "");
        }
        return columnName;
    }

    /**
     * @param sql
     * @param outputColumns
     * @param input
     * @return
     */
    private String extractColumnsForSelectSql(String sql, List<String> outputColumns,
                                              JSONArray input, TaskVO vo, JSONArray parentTimeStamps) throws SqlParserException {
        sql = sql.trim().toLowerCase().replaceAll("\\$", "").replaceAll("，", ",");
        Statement statement = null;
        try {
            statement = CCJSqlParserUtil.parse(sql);
            if (!(statement instanceof Select)) {
                logger.error("going to execute sql {} is delete or drop", sql);
                throw new SqlParserException("目前只支持select语句");
            }
        } catch (JSQLParserException e) {
            vo.setException(new SqlParserException(e.getCause().getMessage()));
            e.printStackTrace();
            return sql;
        }
        Select select = (Select) statement;
        SelectBody selectBody = select.getSelectBody();
        PlainSelect plainSelect = (PlainSelect) selectBody;
        Table targetTable = (Table) plainSelect.getFromItem();

        //当SQL节点之前有多个 父节点的时候， 需要查找对应的inputTable
        int targetTableIndex = 0;
        int unmatchedNum = 0;
        for (int i = 0; i < input.size(); i++) {
            JSONObject inputItemObj = (JSONObject) input.get(i);
            String tempTableName = getTempTableName(inputItemObj, (Long) parentTimeStamps.get(i));
            if (compareTableName(targetTable, tempTableName)) {
                targetTableIndex = i;
            } else {
                unmatchedNum += 1;
            }
        }
        //如果使用的数据表不在连接的节点范围内，抛出异常
//        if (targetTableIndex == 0 && unmatchedNum == input.size()) {
//            logger.warn("SQL node is trying to access invalid data table");
//            throw new SqlParserException("当前操作正在请求非法数据");
//        }
        List<SelectItem> selectItems = plainSelect.getSelectItems();
        List<String> wrappedItems = new ArrayList<>();
        //对型号* 全选做特别处理，查出所有字段名
        for (SelectItem item : selectItems) {
            try {
                SelectExpressionItem expressionItem = (SelectExpressionItem) item;
                Alias itemAlias = expressionItem.getAlias();
                String columnName = extractColumnName(expressionItem);
                if (itemAlias == null) {
                    itemAlias = new Alias(columnName, true);
                    wrappedItems.add(columnName + " AS \"" + itemAlias.getName() + "\"");
                } else {
                    wrappedItems.add(columnName);
                }
                outputColumns.add(itemAlias.getName());
            } catch (Exception e) {
                if (item.getASTNode().jjtGetLastToken().toString().equals("*")) {
                    outputColumns.add("*");
                }
            }
        }

        int start = 7;
        int end = sql.toLowerCase().indexOf("from");
        if (outputColumns.size() == 1 && outputColumns.get(0).equals("*")) {
            outputColumns.remove("*");

            List<String> tableCols = input.getJSONObject(targetTableIndex).getJSONArray("tableCols")
                    .toJavaList(String.class);
            outputColumns.addAll(tableCols);
            if (!outputColumns.contains(ID_COL)) {
                outputColumns.add(ID_COL);
                Collections.reverse(outputColumns);
                return String.format("%s row_number() over() as _record_id_, %s %s",
                        sql.substring(0, start), Joiner.on(",").join(tableCols),
                        sql.substring(end));
            } else {
                Collections.reverse(outputColumns);
                return String.format("%s %s %s", sql.substring(0, start),
                        Joiner.on(",").join(outputColumns), sql.substring(end));
            }
        }

        if (!outputColumns.contains(ID_COL)) {
            outputColumns.add(ID_COL);
            Collections.reverse(outputColumns);
            return String
                    .format("select row_number() over() as _record_id_,%s %s ", StringUtils.join(wrappedItems, ","), sql.substring(end));
        }
        return sql;
    }

    public String initSql(JSONObject conf, List<SqlHelper> sqlHelpers, long timeStamp,
                          String engineName) {
        this.engineName = engineName;
        String outTable = "";
        String sql = "";
        String modifySql = "";
        if (conf.containsKey("adjustSql")) {
            modifySql = conf.getString("adjustSql");
        }
        sql = this.replaceTableName(timeStamp, modifySql);
        JSONArray tableCols = conf.getJSONArray("input").getJSONObject(0).getJSONArray("tableCols");
        sql = this.wrapperNumColName(tableCols, sql);
        if (StringUtils.isEmpty(sql)) {
            return sql;
        }
        JSONArray output = conf.getJSONArray("output");
        outTable = output.getJSONObject(0).getString("tableName") + timeStamp;
        return String.format(SqlTemplate.CREATE_VIEW_SQL, outTable, sql);
    }

    private String wrapperNumColName(JSONArray tableCols, String sqlStr) {
        String FormatedSql = sqlStr.replaceAll("\uFEFF", "").replaceAll("、", "");
        if (StringUtils.isNotEmpty(sqlStr)) {
            List<String> outputColumns = Lists.newArrayList();
            Statement statement = null;
            try {
                statement = CCJSqlParserUtil.parse(FormatedSql);
            } catch (JSQLParserException e) {
                logger.error("something wrong happened, when parse SQL -> {}", sqlStr);
                return "";
            }
            Select select = (Select) statement;
            SelectBody selectBody = select.getSelectBody();
            PlainSelect plainSelect = (PlainSelect) selectBody;
            List<SelectItem> selectItems = plainSelect.getSelectItems();
            for (SelectItem item : selectItems) {
                SelectExpressionItem expressionItem = (SelectExpressionItem) item;
                String colName = expressionItem.getASTNode().jjtGetLastToken().toString();
                if (StringUtil.isAllNumeric(colName)) {
                    colName = "\"" + colName + "\"";
                } else {
                    colName = expressionItem.getExpression().toString();
                    if (null != expressionItem.getAlias()) {
                        colName += expressionItem.getAlias().toString();
                    }
                }
                outputColumns.add(colName);
            }
            int end = sqlStr.toLowerCase().indexOf("from");
            return String.format("%s %s %s", sqlStr.substring(0, 7),
                    Joiner.on(",").join(outputColumns), sqlStr.substring(end));
        }
        return "";
    }

    public void defineOutput(TaskVO vo) throws SqlParserException {
        JSONObject jsonObject = vo.getData();
        String sqlStr = jsonObject.getString("sql");
        if (StringUtils.isEmpty(sqlStr)) {
            logger.warn("sql text is empty!!!!");
            return;
        }
        JSONArray output = new JSONArray();
        String tableName = String
                .format(SqlTemplate.VIEW_TABLE_NAME, vo.getPipelineId(), vo.getId());
        JSONObject item = new JSONObject();
        JSONArray outputColTypes = new JSONArray();
        List<String> outputColNames = new ArrayList<>();
        JSONArray input = jsonObject.getJSONArray("input");
        JSONArray parentTimeStamps = jsonObject.getJSONArray("parentTimeStamps");
        sqlStr = extractColumnsForSelectSql(sqlStr, outputColNames, input, vo, parentTimeStamps);
        jsonObject.put("adjustSql", sqlStr);
        List<String> inputColumnTypes = input.getJSONObject(0).getJSONArray("columnTypes")
                .toJavaList(String.class);
        if (outputColNames.size() == inputColumnTypes.size() + 1) {
            // 默认colType 的数量等于 colName
            // 默认添加 _record_id_ 的类型
            inputColumnTypes.add(0, "BIGINT");
            outputColTypes.addAll(inputColumnTypes);
        } else {
            this.prepareOutputColumnTypes(outputColTypes, outputColNames,
                    input.getJSONObject(0).getJSONArray("tableCols").toJavaList(String.class),
                    inputColumnTypes);
        }

        JSONObject numberFormat = input.getJSONObject(0).getJSONObject("numberFormat");
        if (numberFormat != null && outputColNames != null && !outputColNames.isEmpty()) {
            Set<String> cols = numberFormat.keySet();
            for (String col : cols) {
                if (!outputColNames.contains(col)) {
                    numberFormat.remove(col);
                }
            }
        }
        item.put("numberFormat", numberFormat);

        item.put("tableName", tableName);
        item.put("tableCols", outputColNames);
        item.put("nodeName", vo.getName() == null ? ETLEnum.SQL.toString() : vo.getName());
        item.put("columnTypes", outputColTypes);
        this.setSubTypeForOutput(item);
        output.add(item);
        jsonObject.put("output", output);
        vo.setData(jsonObject);
    }

    public void initTemplate(JSONObject data) {
        data.put("sql", "");
        baseInitTemplate(data);
    }

    public boolean verify(TaskVO vo, List<ApiResultCode> errorCode) {
        JSONObject jsonObject = vo.getData();
        if (jsonObject == null) {
            errorCode.add(ApiResultCode.SYS_ERROR);
            return false;
        }
        String sql = jsonObject.getString("sql");
        if (StringUtils.isEmpty(sql)) {
            return true;
        }
        try {
            Statement statement = CCJSqlParserUtil.parse(sql);
            if (!(statement instanceof Select)) {
                logger.error("just support select sql");
                errorCode.add(ApiResultCode.SQL_NOT_SUPPORT);
                return false;
            }
            if (StringUtils.isEmpty(vo.getParentId())) {
                errorCode.add(ApiResultCode.SQL_NOT_PARENT);
                return false;
            }
            if (!sql.contains("$")) {
                errorCode.add(ApiResultCode.SQL_TABLE_PERMISSION);
                return false;
            }
        } catch (JSQLParserException e) {
            logger.error(e.getMessage());
            errorCode.add(ApiResultCode.SYS_ERROR);
            return false;
        }

        return true;
    }
}